from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys


class opts(object):
    def __init__(self):
        self.parser = argparse.ArgumentParser()
        # basic experiment setting
        self.parser.add_argument(
            "task", default="single_pose", help="single_pose | multi_pose"
        )
        self.parser.add_argument(
            "--dataset", default="active", help="coco_hp | active | active_coco"
        )
        self.parser.add_argument("--exp_id", default="default")
        self.parser.add_argument("--test", action="store_true")
        self.parser.add_argument(
            "--debug",
            type=int,
            default=0,
            help="level of visualization."
            "1: only show the final detection results"
            "2: show the network output features"
            "3: use matplot to display"  # useful when lunching training with ipython notebook
            "4: save all visualizations to disk",
        )
        self.parser.add_argument(
            "--demo",
            default="",
            help="path to image/ image folders/ video. " 'or "webcam"',
        )
        self.parser.add_argument(
            "--load_model", default="", help="path to pretrained model"
        )
        self.parser.add_argument(
            "--resume",
            action="store_true",
            help="resume an experiment. "
            "Reloaded the optimizer parameter and "
            "set load_model to model_last.pth "
            "in the exp dir if load_model is empty.",
        )

        # system
        self.parser.add_argument(
            "--gpus", default="0", help="-1 for CPU, use comma for multiple gpus"
        )
        self.parser.add_argument(
            "--num_workers",
            type=int,
            default=4,
            help="dataloader threads. 0 for single-thread.",
        )
        self.parser.add_argument(
            "--not_cuda_benchmark",
            action="store_true",
            help="disable when the input size is not fixed.",
        )
        self.parser.add_argument(
            "--seed", type=int, default=317, help="random seed"
        )  # from CornerNet

        # log
        self.parser.add_argument(
            "--print_iter",
            type=int,
            default=0,
            help="disable progress bar and print to screen.",
        )
        self.parser.add_argument(
            "--hide_data_time",
            action="store_true",
            help="not display time during training.",
        )
        self.parser.add_argument(
            "--save_all", action="store_true", help="save model to disk every 5 epochs."
        )
        self.parser.add_argument(
            "--metric", default="loss", help="main metric to save best model"
        )
        self.parser.add_argument(
            "--vis_thresh", type=float, default=0.3, help="visualization threshold."
        )
        self.parser.add_argument(
            "--debugger_theme", default="white", choices=["white", "black"]
        )

        # model
        self.parser.add_argument(
            "--arch",
            default="movenet",
            help="model architecture. Currently tested" "movenet",
        )
        self.parser.add_argument(
            "--froze_backbone",
            default=True,
            action="store_false",
            help="whether froze the backbone during training",
        )
        self.parser.add_argument(
            "--head_conv",
            type=int,
            default=-1,
            help="conv layer channels for output head"
            "0 for no conv layer"
            "-1 for default setting: "
            "64 for resnets, 256 for dla and 256 for movenet.",
        )
        self.parser.add_argument(
            "--down_ratio",
            type=int,
            default=4,
            help="output stride. Currently only supports 4.",
        )

        # input
        self.parser.add_argument(
            "--input_res",
            type=int,
            default=-1,
            help="input height and width. -1 for default from "
            "dataset. Will be overriden by input_h | input_w",
        )
        self.parser.add_argument(
            "--input_h",
            type=int,
            default=-1,
            help="input height. -1 for default from dataset.",
        )
        self.parser.add_argument(
            "--input_w",
            type=int,
            default=-1,
            help="input width. -1 for default from dataset.",
        )

        # train
        self.parser.add_argument(
            "--lr", type=float, default=1.25e-4, help="learning rate for batch size 32."
        )
        self.parser.add_argument(
            "--lr_step", type=str, default="90,120", help="drop learning rate by 10."
        )
        self.parser.add_argument(
            "--num_epochs", type=int, default=140, help="total training epochs."
        )
        self.parser.add_argument(
            "--batch_size", type=int, default=32, help="batch size"
        )
        self.parser.add_argument(
            "--master_batch_size",
            type=int,
            default=-1,
            help="batch size on the master gpu.",
        )
        self.parser.add_argument(
            "--num_iters", type=int, default=-1, help="default: #samples / batch_size."
        )
        self.parser.add_argument(
            "--val_intervals",
            type=int,
            default=5,
            help="number of epochs to run validation.",
        )
        self.parser.add_argument(
            "--trainval",
            action="store_true",
            help="include validation in training and " "test on test set",
        )

        # test
        self.parser.add_argument(
            "--flip_test", action="store_true", help="flip data augmentation."
        )
        self.parser.add_argument(
            "--nms", action="store_true", help="run nms in testing."
        )
        self.parser.add_argument(
            "--K", type=int, default=100, help="max number of output objects."
        )

        # dataset
        self.parser.add_argument(
            "--not_rand_crop",
            action="store_true",
            help="not use the random crop data augmentation" "from CornerNet.",
        )
        self.parser.add_argument(
            "--shift",
            type=float,
            default=0.1,
            help="when not using random crop" "apply shift augmentation.",
        )
        self.parser.add_argument(
            "--scale",
            type=float,
            default=0.4,
            help="when not using random crop" "apply scale augmentation.",
        )
        self.parser.add_argument(
            "--rotate",
            type=float,
            default=0,
            help="when not using random crop" "apply rotation augmentation.",
        )
        self.parser.add_argument(
            "--flip",
            type=float,
            default=0.5,
            help="probability of applying flip augmentation.",
        )
        self.parser.add_argument(
            "--no_color_aug",
            action="store_true",
            help="not use the color augmenation " "from CornerNet",
        )
        # multi_pose
        self.parser.add_argument(
            "--aug_rot",
            type=float,
            default=0,
            help="probability of applying " "rotation augmentation.",
        )

        # loss
        self.parser.add_argument(
            "--mse_loss",
            action="store_true",
            help="use mse loss or focal loss to train " "keypoint heatmaps.",
        )
        # # ctdet
        self.parser.add_argument(
            "--reg_loss", default="l1", help="regression loss: sl1 | l1 | l2"
        )
        self.parser.add_argument(
            "--hm_weight",
            type=float,
            default=1,
            help="loss weight for keypoint heatmaps.",
        )
        self.parser.add_argument(
            "--off_weight",
            type=float,
            default=1,
            help="loss weight for keypoint local offsets.",
        )

        # multi_pose
        self.parser.add_argument(
            "--hp_weight",
            type=float,
            default=1,
            help="loss weight for human pose offset.",
        )
        self.parser.add_argument(
            "--hm_hp_weight",
            type=float,
            default=1,
            help="loss weight for human keypoint heatmap.",
        )

    def parse(self, args=""):
        if args == "":
            opt = self.parser.parse_args()
        else:
            opt = self.parser.parse_args(args)

        opt.gpus_str = opt.gpus
        opt.gpus = [int(gpu) for gpu in opt.gpus.split(",")]
        opt.gpus = [i for i in range(len(opt.gpus))] if opt.gpus[0] >= 0 else [-1]
        opt.lr_step = [int(i) for i in opt.lr_step.split(",")]

        if opt.head_conv == -1:  # init default head_conv
            if "dla" in opt.arch:
                opt.head_conv = 256
            elif "movenet" in opt.arch:
                opt.head_conv = 96
            else:
                opt.head_conv = 64
        opt.pad = 127 if "hourglass" in opt.arch else 31
        opt.num_stacks = 2 if opt.arch == "hourglass" else 1

        if opt.trainval:
            opt.val_intervals = 100000000

        if opt.debug > 0:
            opt.num_workers = 0
            opt.batch_size = 1
            opt.gpus = [opt.gpus[0]]
            opt.master_batch_size = -1

        if opt.master_batch_size == -1:
            opt.master_batch_size = opt.batch_size // len(opt.gpus)
        rest_batch_size = opt.batch_size - opt.master_batch_size
        opt.chunk_sizes = [opt.master_batch_size]
        for i in range(len(opt.gpus) - 1):
            slave_chunk_size = rest_batch_size // (len(opt.gpus) - 1)
            if i < rest_batch_size % (len(opt.gpus) - 1):
                slave_chunk_size += 1
            opt.chunk_sizes.append(slave_chunk_size)
        print("training chunk_sizes:", opt.chunk_sizes)

        opt.root_dir = os.path.join(os.path.dirname(__file__), "..", "..")
        opt.data_dir = os.path.join(opt.root_dir, "data")
        opt.exp_dir = os.path.join(opt.root_dir, "exp", opt.task)
        opt.save_dir = os.path.join(opt.exp_dir, opt.exp_id)
        opt.debug_dir = os.path.join(opt.save_dir, "debug")
        print("The output will be saved to ", opt.save_dir)

        if opt.resume and opt.load_model == "":
            model_path = (
                opt.save_dir[:-4] if opt.save_dir.endswith("TEST") else opt.save_dir
            )
            opt.load_model = os.path.join(model_path, "model_last.pth")
        return opt

    def update_dataset_info_and_set_heads(self, opt, dataset):
        input_h, input_w = dataset.default_resolution
        opt.mean, opt.std = dataset.mean, dataset.std
        opt.num_classes = dataset.num_classes

        # input_h(w): opt.input_h overrides opt.input_res overrides dataset default
        input_h = opt.input_res if opt.input_res > 0 else input_h
        input_w = opt.input_res if opt.input_res > 0 else input_w
        opt.input_h = opt.input_h if opt.input_h > 0 else input_h
        opt.input_w = opt.input_w if opt.input_w > 0 else input_w
        opt.output_h = opt.input_h // opt.down_ratio
        opt.output_w = opt.input_w // opt.down_ratio
        opt.input_res = max(opt.input_h, opt.input_w)
        opt.output_res = max(opt.output_h, opt.output_w)

        if opt.task == "multi_pose":
            assert opt.dataset in ["coco_hp"]
            opt.flip_idx = dataset.flip_idx
            opt.heads = {"hm": opt.num_classes, "wh": 2, "hps": 34}
            opt.heads.update({"reg": 2})
            opt.heads.update({"hm_hp": 17})
            opt.heads.update({"hp_offset": 2})
            raise KeyError("The multi_pose is not supported for now.")
        elif opt.task == "single_pose":
            opt.flip_idx = dataset.flip_idx
            opt.heads = {"hm": opt.num_classes, "hps": 34, "hm_hp": 17, "hp_offset": 34}
        else:
            assert 0, "task not defined!"
        print("heads", opt.heads)
        return opt

    def init(self, args=""):
        default_dataset_info = {
            "multi_pose": {
                "default_resolution": [512, 512],
                "num_classes": 1,
                "mean": [0.408, 0.447, 0.470],
                "std": [0.289, 0.274, 0.278],
                "dataset": "coco_hp",
                "num_joints": 17,
                "flip_idx": [
                    [1, 2],
                    [3, 4],
                    [5, 6],
                    [7, 8],
                    [9, 10],
                    [11, 12],
                    [13, 14],
                    [15, 16],
                ],
            },
            "single_pose": {
                "default_resolution": [512, 512],
                "num_classes": 1,
                "mean": [1.0, 1.0, 1.0],
                "std": [1.0, 1.0, 1.0],
                "dataset": "active",
                "num_joints": 17,
                "flip_idx": [
                    [1, 2],
                    [3, 4],
                    [5, 6],
                    [7, 8],
                    [9, 10],
                    [11, 12],
                    [13, 14],
                    [15, 16],
                ],
            },
        }

        class Struct:
            def __init__(self, entries):
                for k, v in entries.items():
                    self.__setattr__(k, v)

        opt = self.parse(args)
        dataset = Struct(default_dataset_info[opt.task])
        opt.dataset = dataset.dataset
        opt = self.update_dataset_info_and_set_heads(opt, dataset)
        return opt
